# Dataset: image/label loading, label caching and augmentation helpers
import contextlib
import hashlib
import math
import os
import glob
import random
import traceback
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool

import cv2
import numpy as np
from pathlib import Path

import psutil
import torch
from PIL import Image, ExifTags, ImageOps
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm

from common.utils import segments2boxes, xyxy2xywh, xywhn2xyxy, xyn2xy, copy_paste, random_perspective, xyxy2xywhn, \
    augment_hsv, mixup, letterbox, colorstr
from constants.constants import IMG_FORMATS, WORLD_SIZE

# Numeric EXIF tag id whose tag name is "Orientation"; exif_size() looks it up in an image's EXIF data.
orientation = next(tag_id for tag_id, tag_name in ExifTags.TAGS.items() if tag_name == "Orientation")

# Worker count for the thread/process pools: between 1 and 8, leaving one core free.
# os.cpu_count() returns None when the CPU count cannot be determined, so fall back to 1
# instead of failing with a TypeError at import time.
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))
TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}"  # progress bar layout passed to tqdm
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # LOCAL_RANK environment variable, -1 when it is not set

def get_hash(paths):
    """Return a SHA256 hex digest derived from the total size of the existing paths and the joined path strings."""
    total_bytes = 0
    for path in paths:
        if os.path.exists(path):  # paths that do not exist contribute no size
            total_bytes += os.path.getsize(path)
    digest = hashlib.sha256(str(total_bytes).encode())  # start from the summed size
    digest.update("".join(paths).encode())  # then mix in the paths, in the given order
    return digest.hexdigest()

def exif_size(img):
    """Return the PIL image size as (width, height), corrected for the EXIF orientation tag."""
    size = img.size  # (width, height) as stored in the file
    with contextlib.suppress(Exception):  # any failure reading EXIF data keeps the stored size
        exif_tags = dict(img._getexif().items())
        if exif_tags[orientation] in (6, 8):  # rotation 270 or 90: width and height swap
            size = (size[1], size[0])
    return size


class Albumentations:
    """Optional image/box augmentation for YOLOv5 backed by the Albumentations library, when it is installed."""

    @staticmethod
    def _version_tuple(version):
        """Return the leading numeric release components of a version string, e.g. "1.4.0rc1" -> (1, 4, 0)."""
        release = []
        for piece in str(version).split("."):
            digits = ""
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            if not digits:
                break
            release.append(int(digits))
            if len(digits) != len(piece):  # a suffix such as "rc1" ends the release part
                break
        return tuple(release)

    def __init__(self, size=640):
        """Build the transform pipeline for images of the given `size`.

        `self.transform` stays None (augmentation disabled) when albumentations is not installed,
        is older than 1.0.3, or the pipeline cannot be constructed; in the latter two cases the
        reason is printed.
        """
        self.transform = None
        prefix = colorstr("albumentations: ")
        try:
            import albumentations as A

            # No `check_version` helper is available in this module, so the version requirement is
            # checked here directly; a failure is reported by the generic handler below.
            if self._version_tuple(A.__version__) < self._version_tuple("1.0.3"):
                raise RuntimeError(f"albumentations>=1.0.3 is required, but {A.__version__} is installed")

            T = [
                A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),
                A.Blur(p=0.01),
                A.MedianBlur(p=0.01),
                A.ToGray(p=0.01),
                A.CLAHE(p=0.01),
                A.RandomBrightnessContrast(p=0.0),
                A.RandomGamma(p=0.0),
                A.ImageCompression(quality_lower=75, p=0.0),
            ]  # transforms
            self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))

            # Report only the transforms that can actually fire (non-zero probability).
            print(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
        except ImportError:  # package not installed, skip
            pass
        except Exception as e:
            print(f"{prefix}{e}")

    def __call__(self, im, labels, p=1.0):
        """Apply the transforms to `im` and `labels` with probability `p`.

        `labels` rows are (class, box...) where the box columns are passed to the pipeline in
        "yolo" format. Returns the (possibly transformed) image and labels; the inputs are returned
        unchanged when no pipeline is available or the random draw skips augmentation.
        """
        if self.transform and random.random() < p:
            new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0])  # transformed
            im, labels = new["image"], np.array([[c, *b] for c, b in zip(new["class_labels"], new["bboxes"])])
        return im, labels


class LoadImagesAndLabels(Dataset):
    """
    Dataset that loads images together with their labels, with optional caching and augmentation.
    """

    def __init__(self,
        data_paths,
        img_size=640,
        batch_size=16,
        augment=False,
        hyp=None,
        rect=False,
        image_weights=False,
        cache_images=False,
        single_cls=False,
        stride=32,
        pad=0.0,
        min_items=0,
        prefix="",
        rank=-1,
        seed = 0,
    ):
        """Resolve image/label paths, load the label cache, filter labels and optionally cache images.

        Args:
            data_paths: Directory, list file, or list of either; handed to get_images_path().
            img_size: Target image size; load_image() scales the longer image side to this value.
            batch_size: Used to assign every image a batch index (needed for rectangular batches).
            augment: Enables augmentation; mosaic loading is used when augment is set and rect is off.
            hyp: Mapping of augmentation hyperparameters read by __getitem__() and the mosaic loaders.
            rect: Use per-batch rectangular shapes; forced off when image_weights is set.
            image_weights: Stored on the instance; disables rect.
            cache_images: False for no caching, "ram" (checked against available memory) or "disk";
                any other truthy value caches into RAM without the memory check.
            single_cls: Overwrite the class column of every label with 0.
            stride: Stride the rectangular batch shapes are rounded to.
            pad: Extra padding, in stride units, added to rectangular batch shapes.
            min_items: When non-zero, keep only images with at least this many labels.
            prefix: Text prepended to printed messages.
            rank: Currently unused (the DDP index selection below is commented out).
            seed: Currently unused (see rank).
        """
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]  # mosaic border offsets (negative half image size)
        self.stride = stride  # total downsampling stride
        self.data_paths = data_paths
        self.albumentations = Albumentations(size=img_size) if augment else None

        self.img_files_path = self.get_images_path(data_paths)

        self.label_files_path = self.get_labels_path(self.img_files_path)
        # load the label cache (built on demand by get_labels_cache)
        labels_cache= self.get_labels_cache(prefix,augment)

        # unpack the cache: drop the bookkeeping keys, the remaining items map image path -> [labels, shape, segments]
        [labels_cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
        labels, shapes, self.segments = zip(*labels_cache.values())
        nl = len(np.concatenate(labels, 0))  # number of labels
        assert nl > 0 or not augment, f"{prefix}All labels empty in {self.label_files_path}"
        self.labels = list(labels)
        self.shapes = np.array(shapes)
        self.img_files_path = list(labels_cache.keys())  # update
        self.label_files_path = self.get_labels_path(self.img_files_path)  # update
        total_number = len(self.img_files_path)
        # filter out images that have fewer than min_items labels
        if min_items:
            include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
            print(f"{prefix}{total_number - len(include)}/{total_number} images filtered from dataset")
            self.change_info(include)

        # Create indices
        image_number = len(self.shapes)
        batch_indexes = np.floor(np.arange(image_number) / batch_size).astype(int)
        batch_number = batch_indexes[-1] + 1
        self.batch_indexes = batch_indexes
        self.image_number = image_number
        self.indices = np.arange(image_number)
        # if rank > -1:  # DDP indices (see: SmartDistributedSampler)
        #     # force each rank (i.e. GPU process) to sample the same subset of data on every epoch
        #     self.indices = self.indices[np.random.RandomState(seed=seed).permutation(image_number) % WORLD_SIZE == RANK]

        # Update labels
        include_class = []  # filter labels to include only these classes (optional)
        self.segments = list(self.segments)
        include_class_array = np.array(include_class).reshape(1, -1)
        for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
            if include_class:
                j = (label[:, 0:1] == include_class_array).any(1)
                self.labels[i] = label[j]
                if segment:
                    self.segments[i] = [segment[idx] for idx, elem in enumerate(j) if elem]
            if single_cls:  # single-class training, merge all classes into 0
                self.labels[i][:, 0] = 0

        # Rectangular Training
        if self.rect:
            self.batch_shapes = self.rectangular_training(batch_number, img_size, stride, pad)

        # Cache images into RAM/disk for faster training
        if cache_images == "ram" and not self.check_cache_ram(prefix=prefix):
            cache_images = False
        self.ims = [None] * image_number
        self.npy_files = [Path(f).with_suffix(".npy") for f in self.img_files_path]
        if cache_images:
            b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
            self.im_hw0, self.im_hw = [None] * image_number, [None] * image_number
            fcn = self.cache_images_to_disk if cache_images == "disk" else self.load_image
            with ThreadPool(NUM_THREADS) as pool:
                results = pool.imap(lambda i: (i, fcn(i)), self.indices)
                pbar = tqdm(results, total=len(self.indices), bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
                for i, x in pbar:
                    if cache_images == "disk":
                        b += self.npy_files[i].stat().st_size
                    else:  # 'ram'
                        self.ims[i], self.im_hw0[i], self.im_hw[
                            i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                        b += self.ims[i].nbytes * WORLD_SIZE
                    pbar.desc = f"{prefix}Caching images ({b / gb:.1f}GB {cache_images})"
                pbar.close()

    def __len__(self):
        """Return the number of images in the dataset."""
        image_paths = self.img_files_path
        return len(image_paths)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        """Return (image tensor, labels tensor, image path, shapes) for the sample at position `index`.

        The image is returned channel-first with its channel order reversed; the labels tensor has
        six columns whose first column is left at zero (collate_fn fills in the image index).
        `shapes` is None for mosaic samples.
        """
        sample_idx = self.indices[index]  # linear, shuffled, or image_weights
        hyp = self.hyp

        use_mosaic = self.mosaic and random.random() < hyp["mosaic"]
        if use_mosaic:
            # Mosaic sample, optionally blended with a second mosaic (MixUp)
            img, labels = self.load_mosaic(sample_idx)
            shapes = None
            if random.random() < hyp["mixup"]:
                img, labels = mixup(img, labels, *self.load_mosaic(random.choice(self.indices)))
        else:
            # Single image, letterboxed to the batch shape (rect) or to the square image size
            img, (h0, w0), (h, w) = self.load_image(sample_idx)
            if self.rect:
                target_shape = self.batch_shapes[self.batch_indexes[sample_idx]]
            else:
                target_shape = self.img_size
            img, ratio, pad = letterbox(img, target_shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[sample_idx].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

            if self.augment:
                img, labels = random_perspective(
                    img,
                    labels,
                    degrees=hyp["degrees"],
                    translate=hyp["translate"],
                    scale=hyp["scale"],
                    shear=hyp["shear"],
                    perspective=hyp["perspective"],
                )

        label_count = len(labels)
        if label_count:  # pixel xyxy back to normalized xywh
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1e-3)

        if self.augment:
            img, labels = self.albumentations(img, labels)
            label_count = len(labels)  # the transform may have dropped boxes

            augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])

            if random.random() < hyp["flipud"]:  # vertical flip mirrors the y centre
                img = np.flipud(img)
                if label_count:
                    labels[:, 2] = 1 - labels[:, 2]

            if random.random() < hyp["fliplr"]:  # horizontal flip mirrors the x centre
                img = np.fliplr(img)
                if label_count:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((label_count, 6))
        if label_count:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # HWC to CHW, BGR to RGB, then make the memory contiguous for torch
        img = np.ascontiguousarray(img.transpose((2, 0, 1))[::-1])

        return torch.from_numpy(img), labels_out, self.img_files_path[sample_idx], shapes


    @staticmethod
    def get_images_path(data_paths):
        """
        获取图片路径列表
        """
        try:
            file_paths = []
            for data_path in data_paths if isinstance(data_paths, list) else [data_paths]:
                # 标准化路径
                standard_data_path = Path(data_path)
                # 文件夹处理
                if standard_data_path.is_dir():
                    file_paths += glob.iglob(str(standard_data_path / "**" / "*.*"), recursive=True)
                # 路径文件判断  一个文件中按行写入图片路径
                elif standard_data_path.is_file():
                    # 构建单个文件路径
                    with open(standard_data_path, 'r') as file:
                        lines = file.read().strip().splitlines()
                        parent = str(standard_data_path.parent) + os.sep
                        file_paths += [info.replace('./', parent) if info.startswith('./') else info for info in lines]
                else:
                    raise Exception(f'{standard_data_path} does not exist')

            # 过滤掉图片非图片路径
            img_files = sorted(x.replace("/", os.sep) for x in file_paths if x.split(".")[-1].lower() in IMG_FORMATS)
            return img_files
        except Exception as e:
            raise Exception(f"加载数据失败{data_paths}")

    @staticmethod
    def get_labels_path(img_files_path):
        """
        通过将 `/images/` 替换为 `/labels/`，并将扩展名替换为 `.txt`，从相应的图像文件路径生成标签文件路径。
        """
        # 默认标签
        # /images/, /labels/
        images_str, labels_str = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"
        # 标签路径
        label_files = [labels_str.join(x.rsplit(images_str, 1)).rsplit(".", 1)[0] + ".txt" for x in img_files_path]
        return label_files

    def get_labels_cache(self, prefix,augment):
        """Load the label cache stored next to the labels directory, rebuilding it when needed.

        The cache is rebuilt with cache_labels() when the cache file is missing, cannot be read, or
        its stored hash does not match the current image/label file lists.

        Args:
            prefix: Text prepended to printed messages.
            augment: When true, a cache without any found label file is rejected.

        Returns:
            The cache dict with its "results" entry removed.

        Raises:
            AssertionError: if no labels were found and `augment` is set.
        """
        cache_path = Path(self.label_files_path[0]).parent.with_suffix(".cache")
        cache, exists = None, False
        if os.path.isfile(cache_path):
            try:
                # NOTE: allow_pickle=True unpickles the file; only use caches from trusted locations.
                cache = np.load(cache_path, allow_pickle=True).item()
                # The hash is order-sensitive, so it must be built exactly as cache_labels() stores
                # it: image paths first, then label paths. Any other order never matches and forces
                # a full rescan on every run.
                exists = cache['hash'] == get_hash(self.img_files_path + self.label_files_path)
            except Exception:
                exists = False  # unreadable or malformed cache file: rebuild it below
        if not exists:
            cache = self.cache_labels(cache_path, prefix)

        # show the cached scan summary
        nf, nm, ne, nc, total_number = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            tqdm(None, desc=prefix + d, total=total_number, initial=total_number, bar_format=TQDM_BAR_FORMAT)  # display cache results
            if cache["msgs"]:
                print("\n".join(cache["msgs"]))  # display warnings
        assert nf > 0 or not augment, f"{prefix}No labels found in {cache_path}"
        return cache

    def cache_labels(self, path=Path("./labels.cache"), prefix=""):
        """Scan every image/label pair, collect the results and save them as a cache file.

        Each pair is checked by verify_image_label() in a process pool. The returned dict maps each
        valid image path to [labels, shape, segments] and additionally holds the keys "hash",
        "results" (found, missing, empty, corrupt, total), "msgs" and "version".

        Args:
            path: Destination of the cache file.
            prefix: Text prepended to printed messages.

        Returns:
            The cache dict; it is returned even when it could not be written to disk.
        """
        cache_info = {}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        # text shown in front of the progress bar
        desc = f"{prefix}Scanning images path {path.parent / path.stem}..."
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(
                pool.imap(self.verify_image_label, zip(self.img_files_path, self.label_files_path, repeat(prefix))),
                desc=desc,
                total=len(self.img_files_path),
                bar_format=TQDM_BAR_FORMAT,
            )
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # corrupt pairs come back with im_file=None and are left out
                    cache_info[im_file] = [lb, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"

        pbar.close()
        if msgs:
            print("\n".join(msgs))
        if nf == 0:
            print(f"{prefix}WARNING ⚠️ No labels found in {path}")
        # get_hash() is order-sensitive: image paths first, then label paths
        cache_info["hash"] = get_hash(self.img_files_path + self.label_files_path)
        cache_info["results"] = (nf, nm, ne, nc, len(self.img_files_path))
        cache_info["msgs"] = msgs  # warnings
        cache_info["version"] = 0.6
        try:
            np.save(path, cache_info)  # numpy appends ".npy" to the file name
            # replace() overwrites an existing cache on every platform, whereas rename() raises on
            # Windows when the target already exists and a stale cache could never be refreshed.
            path.with_suffix(".cache.npy").replace(path)  # remove .npy suffix
            print(f"{prefix}New cache created: {path}")
        except Exception as e:
            # best effort: the scan result is still returned when the cache cannot be written
            print(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}")  # not writeable

        return cache_info

    @staticmethod
    def verify_image_label(args):
        """验证单个图像-标签对，确保图像格式、大小和标签值合法。"""
        image_file_path, label_file_path, prefix = args
        nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, "", []  # number (missing, found, empty, corrupt), message, segments
        try:
            # 检查图像文件是否完整、未损坏
            im = Image.open(image_file_path)
            im.verify()
            shape = exif_size(im)  # image size
            assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
            assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
            if im.format.lower() in ("jpg", "jpeg"):
                with open(image_file_path, "rb") as f:
                    f.seek(-2, 2)
                    if f.read() != b"\xff\xd9":  # corrupt JPEG
                        ImageOps.exif_transpose(Image.open(image_file_path)).save(image_file_path, "JPEG", subsampling=0, quality=100)
                        msg = f"{prefix}WARNING ⚠️ {image_file_path}: corrupt JPEG restored and saved"

            # 验证标签
            if os.path.isfile(label_file_path):
                nf = 1  # label found
                with open(label_file_path) as f:
                    lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                    if any(len(x) > 6 for x in lb):  # is segment
                        classes = np.array([x[0] for x in lb], dtype=np.float32)
                        segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                        lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                    lb = np.array(lb, dtype=np.float32)
                if nl := len(lb): #---> nl = len(lb) if nl:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
                    assert (lb[:,1:] <= 1).all(), f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
                    _, i = np.unique(lb, axis=0, return_index=True)
                    if len(i) < nl:  # duplicate row check
                        lb = lb[i]  # remove duplicates
                        if segments:
                            segments = [segments[x] for x in i]
                        msg = f"{prefix}WARNING ⚠️ {image_file_path}: {nl - len(i)} duplicate labels removed"
                else:
                    ne = 1  # label empty
                    lb = np.zeros((0, 5), dtype=np.float32)
            else:
                nm = 1  # label missing
                lb = np.zeros((0, 5), dtype=np.float32)
            return image_file_path, lb, shape, segments, nm, nf, ne, nc, msg
        except Exception as e:
            nc = 1
            msg = f"{prefix}WARNING ⚠️ {image_file_path}: ignoring corrupt image/label: {e}"
            return [None, None, None, None, nm, nf, ne, nc, msg]


    def change_info(self,indexes):
        """
        改变数据信息
        """
        self.img_files_path = [self.img_files_path[i] for i in indexes]
        self.label_files_path = [self.label_files_path[i] for i in indexes]
        self.labels = [self.labels[i] for i in indexes]
        self.segments = [self.segments[i] for i in indexes]
        self.shapes = self.shapes[indexes]  # wh

    def rectangular_training(self,batch_number,img_size,stride,pad):
        # Sort by aspect ratio
        s = self.shapes  # wh
        ar = s[:, 1] / s[:, 0]  # aspect ratio
        irect = ar.argsort()
        self.change_info(irect)
        ar = ar[irect]

        # Set training image shapes
        shapes = [[1, 1]] * batch_number
        for i in range(batch_number):
            ari = ar[self.batch_indexes == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]
            elif mini > 1:
                shapes[i] = [1, 1 / mini]

        batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride
        return batch_shapes

    def check_cache_ram(self, safety_margin=0.1, prefix=""):
        """Estimate whether the resized dataset fits into the currently available RAM.

        The requirement is extrapolated from up to 30 randomly chosen images and compared, with
        `safety_margin` added, against the available memory. Returns True when caching fits;
        otherwise prints the numbers and returns False.
        """
        gb = 1 << 30  # bytes per gigabyte
        sample_count = min(self.image_number, 30)  # extrapolate from 30 random images
        sampled_bytes = 0
        for _ in range(sample_count):
            sample = cv2.imread(random.choice(self.img_files_path))  # sample image
            scale = self.img_size / max(sample.shape[0], sample.shape[1])  # resize ratio of the longer side
            sampled_bytes += sample.nbytes * scale ** 2
        mem_required = sampled_bytes * self.image_number / sample_count  # bytes required to cache the dataset
        mem = psutil.virtual_memory()
        cache = mem_required * (1 + safety_margin) < mem.available
        if cache:
            return cache
        print(
            f"{prefix}{mem_required / gb:.1f}GB RAM required, "
            f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
            f"{'caching images ✅' if cache else 'not caching images ⚠️'}"
        )
        return cache

    def load_image(self, i):
        """
        Loads an image by index, returning the image, its original dimensions, and resized dimensions.

        Returns (im, original hw, resized hw)
        """
        im, f, fn = (
            self.ims[i],
            self.img_files_path[i],
            self.npy_files[i],
        )
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
            else:  # read image
                im = cv2.imread(f)  # BGR
                assert im is not None, f"Image Not Found {f}"
            h0, w0 = im.shape[:2]  # orig hw
            r = self.img_size / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
                im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized

    def cache_images_to_disk(self, i):
        """Saves an image to disk as an *.npy file for quicker loading, identified by index `i`."""
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.img_files_path[i]))

    def load_mosaic(self, index):
        """Build a 4-image mosaic from image `index` plus 3 randomly chosen images.

        The four images are placed around a random centre on a (2s, 2s) canvas, where s is
        `self.img_size`; their labels and segments are shifted to canvas pixel coordinates. The
        result is then passed through copy_paste() and random_perspective().

        Returns:
            (img4, labels4) as returned by random_perspective().
        """
        labels4, segments4 = [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center y, x
        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
        random.shuffle(indices)  # the requested image may land in any of the four tiles
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img4
            # (x1a, y1a, x2a, y2a) is the target region on the canvas,
            # (x1b, y1b, x2b, y2b) the matching source region of the image
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            # offset of the image origin on the canvas, used to shift its labels
            padw = x1a - x1b
            padh = y1a - y1b

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
            labels4.append(labels)
            segments4.extend(segments)

        # Concat/clip labels
        labels4 = np.concatenate(labels4, 0)
        for x in (labels4[:, 1:], *segments4):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img4, labels4 = replicate(img4, labels4)  # replicate

        # Augment
        img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp["copy_paste"])
        img4, labels4 = random_perspective(
            img4,
            labels4,
            segments4,
            degrees=self.hyp["degrees"],
            translate=self.hyp["translate"],
            scale=self.hyp["scale"],
            shear=self.hyp["shear"],
            perspective=self.hyp["perspective"],
            border=self.mosaic_border,
        )  # border to remove

        return img4, labels4

    def load_mosaic9(self, index):
        """Build a 9-image mosaic from image `index` plus 8 randomly chosen images.

        The images are laid out around a centre tile on a (3s, 3s) canvas, where s is
        `self.img_size`; a random (2s, 2s) window is then cut out and the labels and segments are
        shifted accordingly. The result is passed through copy_paste() and random_perspective().

        Returns:
            (img9, labels9) as returned by random_perspective().
        """
        labels9, segments9 = [], []
        s = self.img_size
        indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
        random.shuffle(indices)
        hp, wp = -1, -1  # height, width previous
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img9
            # c holds the tile position on the canvas; tiles after the first are placed relative
            # to the centre tile (h0, w0) and to the previously placed tile (hp, wp)
            if i == 0:  # center
                img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base canvas, 3s x 3s, for the 9 tiles
                h0, w0 = h, w
                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
            elif i == 1:  # top
                c = s, s - h, s + w, s
            elif i == 2:  # top right
                c = s + wp, s - h, s + wp + w, s
            elif i == 3:  # right
                c = s + w0, s, s + w0 + w, s + h
            elif i == 4:  # bottom right
                c = s + w0, s + hp, s + w0 + w, s + hp + h
            elif i == 5:  # bottom
                c = s + w0 - w, s + h0, s + w0, s + h0 + h
            elif i == 6:  # bottom left
                c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
            elif i == 7:  # left
                c = s - w, s + h0 - h, s, s + h0
            elif i == 8:  # top left
                c = s - w, s + h0 - hp - h, s, s + h0 - hp

            padx, pady = c[:2]
            x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
            labels9.append(labels)
            segments9.extend(segments)

            # Image
            img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
            hp, wp = h, w  # height, width previous

        # Offset
        yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # top-left corner (y, x) of the 2s x 2s crop
        img9 = img9[yc: yc + 2 * s, xc: xc + 2 * s]

        # Concat/clip labels
        labels9 = np.concatenate(labels9, 0)
        # shift the boxes and segments into the coordinate frame of the crop
        labels9[:, [1, 3]] -= xc
        labels9[:, [2, 4]] -= yc
        c = np.array([xc, yc])  # centers
        segments9 = [x - c for x in segments9]

        for x in (labels9[:, 1:], *segments9):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img9, labels9 = replicate(img9, labels9)  # replicate

        # Augment
        img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp["copy_paste"])
        img9, labels9 = random_perspective(
            img9,
            labels9,
            segments9,
            degrees=self.hyp["degrees"],
            translate=self.hyp["translate"],
            scale=self.hyp["scale"],
            shear=self.hyp["shear"],
            perspective=self.hyp["perspective"],
            border=self.mosaic_border,
        )  # border to remove

        return img9, labels9

    @staticmethod
    def collate_fn(batch):
        """Batches images, labels, paths, and shapes, assigning unique indices to targets in merged label tensor."""
        im, label, path, shapes = zip(*batch)  # transposed
        for i, lb in enumerate(label):
            lb[:, 0] = i  # add target image index for build_targets()
        return torch.stack(im, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        """Bundles a batch's data by quartering the number of shapes and paths, preparing it for model input."""
        im, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode="bilinear", align_corners=False)[
                    0
                ].type(im[i].type())
                lb = label[i]
            else:
                im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
                lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            im4.append(im1)
            label4.append(lb)

        for i, lb in enumerate(label4):
            lb[:, 0] = i  # add target image index for build_targets()

        return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4